Skip to content

feat(env): mask all-tool-failed rollouts (SimpleTIR-style void turn masking)#1416

Open
mvanhorn wants to merge 1 commit into
PrimeIntellect-ai:mainfrom
mvanhorn:fix/315-tool-failure-mask-rollouts
Open

feat(env): mask all-tool-failed rollouts (SimpleTIR-style void turn masking)#1416
mvanhorn wants to merge 1 commit into
PrimeIntellect-ai:mainfrom
mvanhorn:fix/315-tool-failure-mask-rollouts

Conversation

@mvanhorn
Copy link
Copy Markdown

@mvanhorn mvanhorn commented May 19, 2026

Summary

verifiers/envs/tool_env.py

  1. Extend ToolEnv.__init__ signature with mask_all_failed_tool_calls: bool = False. Store it as self.mask_all_failed_tool_calls. Append if mask_all_failed_tool_calls: self.add_metric(self.void_turn_rollouts_metric).

  2. Add a new method that records outcomes inside env_response. Diff against the existing function (only the two tool_messages.append(...) sites change — one for success, one for the exception path):

async def env_response(self, messages, state, **kwargs):
    last_msg = cast(vf.AssistantMessage, messages[-1])
    assert last_msg.tool_calls is not None
    outcomes = state.setdefault("tool_call_outcomes", [])
    tool_messages = []
    for tool_call in last_msg.tool_calls:
        tool_call_id = tool_call.id
        try:
            tool_name = tool_call.name
            tool_args = json.loads(tool_call.arguments)
        except Exception as e:
            if self._should_stop_for_error(e):
                raise vf.ToolParseError from e
            outcomes.append("error")
            tool_messages.append(ToolMessage(role="tool",
                content=self.error_formatter(e), tool_call_id=tool_call_id))
            continue
        try:
            tool_message = await self.call_tool(tool_name, tool_args, tool_call_id)
            outcomes.append("ok")
            tool_messages.append(tool_message)
        except Exception as e:
            if self._should_stop_for_error(e):
                raise vf.ToolCallError from e
            outcomes.append("error")
            tool_messages.append(ToolMessage(role="tool",
                content=self.error_formatter(e), tool_call_id=tool_call_id))
    return tool_messages
  1. Add a _finalize_state hook that runs after the rollout loop — set state["masked"] when the flag is on and all outcomes are "error". Override the base Environment.rollout post-step OR add this logic at the start of score_objects via the rubric (decided in step 4). For minimal blast radius, do it inside ToolEnv via overriding _post_rollout_hook if it exists, otherwise compute lazily in the metric callable + in the base rubric mask check.

Concretely: add a method on ToolEnv:

def _should_mask(self, state) -> bool:
    outcomes = state.get("tool_call_outcomes") or []
    return (
        self.mask_all_failed_tool_calls
        and len(outcomes) > 0
        and all(o == "error" for o in outcomes)
    )

Wire this into the existing rollout finalization by setting state["masked"] = self._should_mask(state) after the MultiTurnEnv.rollout loop completes. In verifiers/envs/multiturn_env.py::MultiTurnEnv.rollout there is already a finalization section — ToolEnv overrides nothing today, so add async def rollout(self, *args, **kwargs): that calls super().rollout(...) and post-processes state["masked"]. (Confirm signature by reading multiturn_env.py before editing.)

  1. Add the metric callable:
async def void_turn_rollouts_metric(self, state: vf.State) -> float:
    return 1.0 if state.get("masked") else 0.0

verifiers/rubrics/rubric.py

In Rubric.score_objects(state), before computing rewards, short-circuit when state.get("masked") is truthy:

def score_objects(self, state: State) -> dict[str, Any]:
    if state.get("masked"):
        # Preserve reward-func key shape for stable downstream schema.
        return {
            "reward": 0.0,
            "masked": True,
            **{name: 0.0 for name in self._get_individual_reward_func_names()},
        }
    # ...existing body...

This keeps the JSON output schema stable (downstream parsers will not see missing keys) and sets a single explicit zero reward.

Docs

Add a short paragraph to verifiers/envs/AGENTS.md under "Optional flags" describing the new flag and pointing at the SimpleTIR reference.

Why this matters

Issue #315 (filed by @faresobeid, COLLABORATOR) asks for an option to "mask rollouts where all tool calls failed," referencing the SimpleTIR paper (arXiv:2509.02479). In RL training, a trajectory in which every tool invocation failed contributes near-zero signal and can destabilize the gradient. SimpleTIR's contribution is to skip ("mask") these void turns when computing rewards. The verifiers ToolEnv and StatefulToolEnv already track tool calls, catch exceptions, and reply with an error_formattered ToolMessage — but they do not surface a "this rollout produced only failed tool calls" signal that downstream consumers (rubrics, trainers) can use.

Acceptance:

  • ToolEnv (and by inheritance StatefulToolEnv) tracks per-rollout tool call outcomes in state as state["tool_call_outcomes"]: list[Literal["ok","error"]].
  • ToolEnv gains a constructor flag mask_all_failed_tool_calls: bool = False (default off — opt-in, non-breaking).
  • When the flag is on AND every recorded tool call outcome is "error" (and len(outcomes) > 0), the rollout state is marked state["masked"] = True and the rubric returns reward = 0.0 with no per-reward-func contributions, AND a new boolean metric void_turn_rollouts is exposed via add_metric.
  • Existing rollouts (without the flag) behave identically. No public API removed or renamed.
  • Tests cover: (a) flag off → unchanged behavior, (b) flag on + all errors → masked + reward 0, (c) flag on + mixed outcomes → unmasked, (d) flag on + no tool calls → unmasked.

Testing

tests/envs/test_tool_env_void_mask.py (new)

import pytest

import verifiers as vf
from verifiers.envs.tool_env import ToolEnv


def make_env(*, mask: bool):
    def good_tool() -> str:
        return "ok"

    def bad_tool() -> str:
        raise RuntimeError("boom")

    return ToolEnv(
        tools=[good_tool, bad_tool],
        mask_all_failed_tool_calls=mask,
        # ... minimal dataset / rubric fixture
    )


@pytest.mark.asyncio
async def test_flag_off_no_mask_key(monkeypatch):
    env = make_env(mask=False)
    state = await run_two_failing_tool_calls(env)
    assert state.get("masked") in (None, False)


@pytest.mark.asyncio
async def test_flag_on_all_errors_masked(monkeypatch):
    env = make_env(mask=True)
    state = await run_two_failing_tool_calls(env)
    assert state["masked"] is True
    assert state["tool_call_outcomes"] == ["error", "error"]


@pytest.mark.asyncio
async def test_flag_on_mixed_outcomes_unmasked(monkeypatch):
    env = make_env(mask=True)
    state = await run_one_good_one_bad(env)
    assert state["masked"] is False
    assert state["tool_call_outcomes"] == ["ok", "error"]


@pytest.mark.asyncio
async def test_flag_on_no_tool_calls_unmasked(monkeypatch):
    env = make_env(mask=True)
    state = await run_assistant_only_no_tools(env)
    # Empty outcomes list MUST NOT trigger mask
    assert state.get("masked") in (None, False)


@pytest.mark.asyncio
async def test_masked_rollout_scores_zero():
    env = make_env(mask=True)
    state = await run_two_failing_tool_calls(env)
    scored = env.rubric.score_objects(state)
    assert scored["reward"] == 0.0
    assert scored["masked"] is True

run_two_failing_tool_calls / run_one_good_one_bad / run_assistant_only_no_tools are async helpers that drive a synthetic single rollout through env.env_response() directly with hand-crafted AssistantMessage containing ToolCall objects, mirroring the style of existing tests under tests/envs/.

Run with uv run pytest tests/envs/test_tool_env_void_mask.py -v.

Fixes #315

AI was used for assistance.


Note

Medium Risk
Opt-in masking changes reward computation (zeroing rewards/metrics) and adds new state fields that downstream training/eval pipelines may implicitly rely on; behavior is gated by a new flag but touches core env/rubric scoring paths.

Overview
Adds mask_all_failed_tool_calls to ToolEnv (and thus StatefulToolEnv) to record per-tool-call outcomes in state["tool_call_outcomes"], set state["masked"] when at least one tool was called and all outcomes are errors, and emit a new void_turn_rollouts metric.

Updates Rubric.score_rollout and Rubric.score_group to short-circuit masked rollouts to zero reward (and zero out non-metric function contributions), while keeping metric-weight (weight==0) functions evaluable. Adds tests covering flag on/off, mixed/no tool calls, and StatefulToolEnv behavior, plus docs in AGENTS.md describing the new flag and masking semantics.

Reviewed by Cursor Bugbot for commit d54ffee. Bugbot is set up for automated code reviews on this repo. Configure here.

Note

Add mask_all_failed_tool_calls flag to ToolEnv for SimpleTIR-style void turn masking

  • Adds a mask_all_failed_tool_calls constructor parameter to ToolEnv and StatefulToolEnv; when enabled, rollouts where every tool call fails are marked state['masked'] = True.
  • Per-call outcomes ('ok'/'error') are recorded in state['tool_call_outcomes'] during env_response for both env types.
  • Masked rollouts receive a reward of 0.0 and zero out non-zero-weight metrics in Rubric.score_rollout and score_rollouts; zero-weight metrics still execute.
  • Exposes a void_turn_rollouts metric (1.0 when masked, 0.0 otherwise) when the flag is enabled.
  • Behavioral Change: rollouts with all-failed tool calls now return reward 0.0 instead of their computed reward when mask_all_failed_tool_calls=True.

Macroscope summarized d54ffee.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Implement SimpleTIR void turn masking

1 participant